Fix UE8M0 packing: mask mantissa bits when extracting fp32 exponents #35

Open
yhyang201 wants to merge 1 commit into sgl-project:main from yhyang201:fix/ue8m0-packing-mantissa-leak

@yhyang201

Same fix as deepseek-ai#337.

transpose_and_pack_fp32_into_ue8m0 and pack_fp32_into_ue8m0 pack 4 fp32
exponent bytes into one uint32 using shift-and-OR:

```cpp
packed |= (values[0] >> 23u);   // byte 0: OK
packed |= (values[1] >> 15u);   // byte 1: mantissa leaks into byte 0
packed |= (values[2] >>  7u);   // byte 2: mantissa leaks into bytes 0-1
packed |= (values[3] <<  1u);   // byte 3: mantissa leaks into bytes 0-2
```

When fp32 scale factors have non-zero mantissa (not exact powers of two),
mantissa bits leak into adjacent exponent fields. For example, packing four
copies of 0.006975 (exp=0x77) produces 0x77ffffff instead of 0x77777777.
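
The leak can be checked on the CPU without a GPU. The following is a minimal Python sketch that emulates the shift-and-OR sequence on raw fp32 bit patterns. The helper names `fp32_bits` and `pack_shift_or` are hypothetical and exist only for this illustration; the 32-bit mask stands in for uint32 wraparound on the left shift.

```python
import struct

MASK32 = 0xFFFFFFFF

def fp32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def pack_shift_or(values) -> int:
    """Emulate the unmasked shift-and-OR packing (illustration only)."""
    v = [fp32_bits(x) for x in values]
    packed = 0
    packed |= v[0] >> 23             # exponent lands in byte 0
    packed |= v[1] >> 15             # exponent in byte 1, top mantissa bits in byte 0
    packed |= v[2] >> 7              # exponent in byte 2, mantissa in bytes 0-1
    packed |= (v[3] << 1) & MASK32   # exponent in byte 3, mantissa in bytes 0-2
    return packed

print(hex(fp32_bits(0.006975)))               # 0x3be48e8a: exponent 0x77, mantissa 0x648e8a
print(hex(pack_shift_or([0.006975] * 4)))     # 0x77ffffff
print(hex(pack_shift_or([0.0078125] * 4)))    # 0x78787878
```

The power-of-two input only packs correctly because its mantissa is all zeros, so there is nothing to leak.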

This causes NaN in fp8_einsum when callers pass plain float32 scales
(e.g. sglang's activation quantization for DeepSeek-V4 on Blackwell).

The torch fallback get_mn_major_tma_aligned_packed_ue8m0_tensor_torch is
correct (uses >> 23 then .to(kUInt8) which naturally truncates), but the
CUDA kernel path doesn't truncate.
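
To illustrate why the per-element route is safe, here is a small pure-Python stand-in (not the torch code itself). It assumes that `ctypes.c_uint8` behaves like the uint8 cast, i.e. it keeps only the low 8 bits; `exponent_byte` is a hypothetical name.

```python
import ctypes
import struct

def fp32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def exponent_byte(x: float) -> int:
    """Shift out the 23 mantissa bits, then truncate to 8 bits like a uint8 cast."""
    return ctypes.c_uint8(fp32_bits(x) >> 23).value

print(hex(exponent_byte(0.006975)))    # 0x77: the mantissa is discarded by the shift
print(hex(exponent_byte(0.0078125)))   # 0x78
print(hex(exponent_byte(-1.0)))        # 0x7f: the truncation also drops the sign bit
```

Because each value is reduced to a single byte before any bytes are combined, no mantissa bits survive to contaminate a neighbouring field.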

Fix: extract each exponent with (val >> 23) & 0xFF and shift to the correct
byte position, in a shared pack_4_fp32_exponents helper.
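
As a rough CPU-side model of the fix, the sketch below masks each exponent before placing it in its byte. It is a Python illustration of the idea, not the CUDA helper from this PR; `pack_4_fp32_exponents_sketch` is a hypothetical name, and it assumes element `i` goes to byte `i` as in the snippet above.

```python
import struct

def fp32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def pack_4_fp32_exponents_sketch(values) -> int:
    """Pack the 8-bit exponents of four fp32 values into one 32-bit word."""
    assert len(values) == 4
    packed = 0
    for i, x in enumerate(values):
        exponent = (fp32_bits(x) >> 23) & 0xFF   # drop mantissa and sign
        packed |= exponent << (8 * i)            # element i -> byte i
    return packed

print(hex(pack_4_fp32_exponents_sketch([0.006975] * 4)))                  # 0x77777777
print(hex(pack_4_fp32_exponents_sketch([0.006975, 0.0078125, 1.0, 3.0]))) # 0x807f7877
```

Masking before the shift costs one extra AND per element but removes the dependency on the caller passing exact powers of two.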

Reproduction

```python
import torch
from deep_gemm.utils.layout import get_mn_major_tma_aligned_packed_ue8m0_tensor

# Packing bug: non-power-of-2 fp32 value gets mantissa leaked into packed result
sf = torch.full((4, 8), 0.006975, device="cuda", dtype=torch.float32)
packed = get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
assert packed[0, 0].item() == 0x77777777, f"got {hex(packed[0, 0].item())}"  # fails: 0x77ffffff

# Power-of-2 value works because mantissa is zero
sf2 = torch.full((4, 8), 0.0078125, device="cuda", dtype=torch.float32)
packed2 = get_mn_major_tma_aligned_packed_ue8m0_tensor(sf2)
assert packed2[0, 0].item() == 0x78787878  # passes
```

The transpose_and_pack_fp32_into_ue8m0 and pack_fp32_into_ue8m0 kernels
pack 4 fp32 exponent bytes into one uint32 using a shift-and-OR trick
that relies on mantissa bits being zero. When the input fp32 scale
factors are not exact powers of two (i.e. have non-zero mantissa),
mantissa bits leak into adjacent exponent fields, producing garbage
values such as 0xFF (the all-ones exponent code, which UE8M0 reserves
for NaN).

This causes NaN output in fp8_einsum when callers pass plain float32
scale factors (e.g. sglang's sglang_per_token_group_quant_fp8), because
the kernel internally converts fp32 scales to ue8m0 via these functions.

Fix: extract each exponent with (val >> 23) & 0xFF and shift to the
correct byte position, extracted into a shared helper function.